//! rcu 读写锁
//!
//! 基于 qemu 实现的 rcu 读写锁, 读端基本无锁, 写端需要调用者维护一致性.
//!
//! 注意:
//!
//! 用户态无法保证线程不被抢占, 因此宽限期的结束只能推迟到所有读端全部退出,
//! 这可能引起资源释放的滞后.
#![feature(lazy_cell)]
#![feature(hash_extract_if)]
#![feature(negative_impls)]
#![allow(unsafe_code)]
#![allow(clippy::linkedlist)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::missing_errors_doc)]

use std::{
    cell::{RefCell, UnsafeCell},
    collections::{HashMap, LinkedList},
    fmt,
    ops::{Deref, DerefMut},
    ptr::NonNull,
    sync::{
        atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering},
        Arc, Condvar, LazyLock, LockResult, Mutex, MutexGuard, TryLockError, TryLockResult,
    },
    thread::{self, JoinHandle, ThreadId},
    time::Duration,
};

use barrier::{smp_mb, smp_mb_global, smp_mb_placeholder};

// Per-thread rcu state that is shared (through an `Arc`) between the owning
// thread and the thread running `synchronize_rcu`.
// * ctr: grace-period counter value copied from `RCU_GP_CTR` when the thread enters
//   its outermost read-side critical section; reset to 0 when it leaves it
// * waiting: set by `synchronize_rcu` to ask the thread to signal `RCU_GP_EVENT`
//   on its next outermost unlock
// * force_rcu: notifiers registered by the thread; `synchronize_rcu` calls them while
//   a `drain_call_rcu` is in progress and the thread is still inside a read section
struct RcuReaderDataInter {
    ctr: AtomicU64,
    waiting: AtomicBool,
    force_rcu: Mutex<LinkedList<Box<dyn Fn() + Send>>>,
}

// Thread-local record describing one thread's rcu state.
// * data: the shared part, also reachable from the global registry
// * depth: nesting depth of read-side critical sections on this thread;
//   only the owning thread modifies it
struct RcuReaderData {
    data: Arc<RcuReaderDataInter>,
    depth: u32,
}

thread_local! {
    // The per-thread `RcuReaderData` instance. It starts outside any read section:
    // `ctr` is 0, `waiting` is false, no notifiers are registered and `depth` is 0.
    static RCU_READER: RefCell<RcuReaderData> = RefCell::new(RcuReaderData {
        data: Arc::new(RcuReaderDataInter {
            ctr: AtomicU64::new(0),
            waiting: AtomicBool::new(false),
            force_rcu: Mutex::new(LinkedList::new()),
        }),
        depth: 0,
    });
}

// Global table of registered readers. `rcu_register_thread` inserts the calling
// thread's `RcuReaderDataInter` here, and `synchronize_rcu` walks the table to
// communicate with the worker threads.
static RCU_REGISTRY_LOCK: LazyLock<Mutex<HashMap<ThreadId, Arc<RcuReaderDataInter>>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));

// Per-thread address (stored as `usize`) of the thread-local `depth` field, filled in
// by `rcu_register_thread` and dereferenced by `rcu_info`.
// `rcu_register_thread` and `rcu_unregister_thread` take `RCU_REGISTRY_LOCK` before this lock.
static DEPTH_COUNT: LazyLock<Mutex<HashMap<ThreadId, usize>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));

/// 注册当前线程到 rcu 管理表中
///
/// 每一个要使用 rcu 锁的线程都需要调用该函数注册 rcu 表,
/// 否则线程不能使用 rcu 锁.
///
/// 注意:
///
/// 当线程退出时需要手动调用 `rcu_unregister_thread` 来卸载 rcu 数据
pub fn rcu_register_thread() {
    RCU_READER.with(|f| {
        let gp = f.borrow();
        let inter = gp.data.clone();
        debug_assert!(inter.ctr.load(Ordering::Relaxed) == 0);
        let mut lock = RCU_REGISTRY_LOCK.lock().unwrap();
        lock.insert(thread::current().id(), inter);
        let mut lock = DEPTH_COUNT.lock().unwrap();
        let ptr = std::ptr::addr_of!(gp.depth) as usize;
        lock.insert(thread::current().id(), ptr);
    });
}

/// Removes the calling thread from the global rcu tables.
///
/// # Panics
/// Panics if the thread is not present in either table, or if one of the
/// tables' mutexes is poisoned.
pub fn rcu_unregister_thread() {
    RCU_READER.with(|_| {
        let tid = thread::current().id();

        // Same lock order as registration; both guards are held until the closure ends.
        let mut registry = RCU_REGISTRY_LOCK.lock().unwrap();
        registry.remove(&tid).unwrap();

        let mut depths = DEPTH_COUNT.lock().unwrap();
        depths.remove(&tid).unwrap();
    });
}

/// Registers a notifier on the calling thread's rcu data.
///
/// `synchronize_rcu` invokes the registered notifiers when a `drain_call_rcu`
/// is in progress and this thread is still holding up the grace period.
///
/// There is no matching removal function: the notifier is stored in the
/// thread's rcu data and stays there for as long as that data exists.
///
/// # Panics
/// Panics if the thread's rcu data is already borrowed or the notifier list's
/// mutex is poisoned.
pub fn rcu_add_force_rcu_notifier<F>(notify: F)
where
    F: Fn() + Send + 'static,
{
    RCU_READER.with(|cell| {
        let reader = cell.borrow_mut();
        let boxed: Box<dyn Fn() + Send> = Box::new(notify);
        reader.data.force_rcu.lock().unwrap().push_back(boxed);
    });
}

// Initial value of the global grace-period counter. It is non-zero, so a reader's
// `ctr` of 0 always means "outside any read section".
const RCU_GP_LOCKED: u32 = 1;
// Amount added to the global counter by `synchronize_rcu` for each grace period.
const RCU_GP_CTR_G: u64 = 2;

// Global grace-period counter; readers copy it into their own `ctr` on outermost lock.
static RCU_GP_CTR: AtomicU64 = AtomicU64::new(RCU_GP_LOCKED as u64);
// Mutex handed to `RCU_GP_EVENT` when `synchronize_rcu` waits; the bool is never read.
static RCU_GP_COND: Mutex<bool> = Mutex::new(true);
// Signalled by `rcu_read_unlock` when a reader that was asked to report leaves
// its outermost read section.
static RCU_GP_EVENT: Condvar = Condvar::new();

// Enters a read-side critical section on the calling thread.
// Only the outermost entry publishes the current grace-period counter;
// nested entries just bump the depth.
fn rcu_read_lock() {
    RCU_READER.with(|cell| {
        let mut reader = cell.borrow_mut();
        reader.depth += 1;

        if reader.depth <= 1 {
            // Outermost lock: record which grace period this reader belongs to.
            let snapshot = RCU_GP_CTR.load(Ordering::Relaxed);
            reader.data.ctr.store(snapshot, Ordering::Relaxed);

            smp_mb_placeholder();
        }
    });
}

// Leaves a read-side critical section on the calling thread.
// When the outermost section ends, the reader's counter is cleared and, if
// `synchronize_rcu` asked for it, the grace-period event is signalled.
//
// Panics if the thread is not inside a read section (depth is 0).
fn rcu_read_unlock() {
    RCU_READER.with(|f| {
        let mut gp = f.borrow_mut();

        assert!(gp.depth != 0);
        gp.depth -= 1;

        if gp.depth == 0 {
            let shared = &gp.data;
            // 0 marks the thread as quiescent.
            shared.ctr.store(0, Ordering::Release);

            smp_mb_placeholder();
            let asked_to_report = shared.waiting.load(Ordering::Relaxed);
            if asked_to_report {
                shared.waiting.store(false, Ordering::Relaxed);
                RCU_GP_EVENT.notify_one();
            }
        }
    });
}

// Queue of deferred callbacks. `call_rcu` pushes at the front and the rcu thread
// pops from the back, so callbacks run in submission order.
static RCU_GP_HEAD: Mutex<LinkedList<Box<dyn FnOnce() + Send>>> = Mutex::new(LinkedList::new());

// Once at least this many callbacks are pending, the rcu thread stops delaying
// to batch more of them.
const RCU_CALL_MIN_SIZE: u32 = 30;
// Number of queued callbacks that the rcu thread has not yet claimed.
static RCU_CALL_COUNT: AtomicU32 = AtomicU32::new(0);

// Mutex handed to `RCU_CALL_READY_EVENT` when the rcu thread waits; the bool is never read.
static RCU_CALL_READY_EVENT_COND: Mutex<bool> = Mutex::new(true);
// Signalled by `call_rcu` after a callback has been queued.
static RCU_CALL_READY_EVENT: Condvar = Condvar::new();

/// Queues `f` to be run after a grace period has elapsed.
///
/// The callback is appended to the global queue and the rcu thread is
/// notified; the rcu thread runs it after its next `synchronize_rcu`.
///
/// # Panics
/// Panics if the callback queue's mutex is poisoned.
pub fn call_rcu<F>(f: F)
where
    F: FnOnce() + Send + 'static,
{
    let callback: Box<dyn FnOnce() + Send> = Box::new(f);

    // The queue stays locked until the function returns, so the push, the
    // counter update and the notification happen under the same lock.
    let mut queue = RCU_GP_HEAD.lock().unwrap();
    queue.push_front(callback);
    RCU_CALL_COUNT.fetch_add(1, Ordering::SeqCst);
    RCU_CALL_READY_EVENT.notify_one();
}

// Number of `drain_call_rcu` calls currently in progress. While it is non-zero,
// `synchronize_rcu` invokes the force-rcu notifiers of readers that are still
// inside a read section.
static IN_DRAIN_CALL_RCU: AtomicU32 = AtomicU32::new(0);

/// Forces pending rcu callbacks to run and waits for them.
///
/// Queues a marker callback through `call_rcu` and blocks until the rcu thread
/// has run it. Because the queue is processed in submission order, every
/// callback queued before this call has been run by the time it returns.
///
/// The completion is tracked with a flag guarded by a mutex, so the wait is
/// immune both to the marker running before the caller starts waiting and to
/// spurious condition-variable wakeups.
///
/// # Panics
/// Panics if the internal mutexes are poisoned.
pub fn drain_call_rcu() {
    // (done flag, condvar): the flag starts false and is set by the marker callback.
    let event = Arc::new((Mutex::new(false), Condvar::new()));

    IN_DRAIN_CALL_RCU.fetch_add(1, Ordering::SeqCst);
    let signal = Arc::clone(&event);
    call_rcu(move || {
        let (done, cond) = &*signal;
        // Set the flag under the mutex so the waiter cannot miss the update.
        *done.lock().unwrap() = true;
        cond.notify_one();
    });
    {
        let (done, cond) = &*event;
        let guard = done.lock().unwrap();
        // Returns immediately if the marker already ran; otherwise sleeps until it does.
        let _guard = cond.wait_while(guard, |finished| !*finished).unwrap();
    }
    IN_DRAIN_CALL_RCU.fetch_sub(1, Ordering::SeqCst);
}

// Tells whether the reader owning `ctr` is still holding up the current grace period.
// Returns false when the reader is outside any read section (`ctr` is 0) or entered
// its read section with the current value of `RCU_GP_CTR`; in both cases the
// grace period does not have to wait for it.
#[inline]
fn rcu_gp_ongoing(ctr: &AtomicU64) -> bool {
    match ctr.load(Ordering::Relaxed) {
        0 => false,
        seen => seen != RCU_GP_CTR.load(Ordering::Relaxed),
    }
}

static RCU_SYNC_LOCK: Mutex<bool> = Mutex::new(false);

/// 同步等待 rcu 宽限期结束
pub fn synchronize_rcu() {
    let _lock = RCU_SYNC_LOCK.lock().unwrap();
    let mut sync = false;

    smp_mb_global();

    loop {
        {
            let mut registry = RCU_REGISTRY_LOCK.lock().unwrap();
            if registry.is_empty() {
                return;
            }

            if !sync {
                RCU_GP_CTR.fetch_add(RCU_GP_CTR_G, Ordering::SeqCst);
                smp_mb();
                sync = true;
            }

            for index in registry.iter_mut() {
                index.1.waiting.store(true, Ordering::Relaxed);
            }

            smp_mb_global();

            let qsreaders = registry
                .extract_if(|_, v| {
                    if !rcu_gp_ongoing(&v.ctr) {
                        // 需要回调的数据暂时取出
                        v.waiting.store(false, Ordering::Relaxed);
                        true
                    } else if IN_DRAIN_CALL_RCU.load(Ordering::Relaxed) != 0 {
                        let n = v.force_rcu.lock().unwrap();
                        for notify in n.iter() {
                            notify();
                        }
                        false
                    } else {
                        false
                    }
                })
                .collect::<HashMap<_, _>>();

            // 等待所有线程经历静默期
            if registry.is_empty() {
                registry.extend(qsreaders);
                return;
            }
        }
        let _ = RCU_GP_EVENT.wait(RCU_GP_COND.lock().unwrap());
    }
}

/// Starts the rcu grace-period thread.
///
/// Registers the calling thread as an rcu reader and spawns the `rcu_gp`
/// thread, which repeatedly waits for queued callbacks, lets a grace period
/// elapse with `synchronize_rcu` and then runs the callbacks it claimed.
///
/// Note:
///
/// This function is meant to be called exactly once, from the main thread.
///
/// # Panics
/// Panics if the thread cannot be spawned.
pub fn soft_rcu_init() -> JoinHandle<()> {
    // Pause between two looks at the pending-callback counter while batching.
    const BATCH_DELAY: Duration = Duration::from_millis(10);
    // `call_rcu` signals `RCU_CALL_READY_EVENT` without holding the mutex the
    // waits below use, so a notification can be lost. Bounding each wait turns a
    // lost wakeup into a short delay instead of an indefinite sleep.
    const MAX_EVENT_WAIT: Duration = Duration::from_millis(100);

    rcu_register_thread();

    thread::Builder::new()
        .name("rcu_gp".into())
        .spawn(|| {
            rcu_register_thread();
            loop {
                // Wait for work. With fewer than `RCU_CALL_MIN_SIZE` callbacks pending,
                // linger for a few rounds so more of them can share one grace period.
                let mut tries: u32 = 1;
                let mut n = RCU_CALL_COUNT.load(Ordering::Relaxed);
                while n == 0 || (n < RCU_CALL_MIN_SIZE && tries <= 5) {
                    thread::sleep(BATCH_DELAY);
                    if n == 0 && RCU_CALL_COUNT.load(Ordering::Relaxed) == 0 {
                        let guard = RCU_CALL_READY_EVENT_COND.lock().unwrap();
                        let _ = RCU_CALL_READY_EVENT.wait_timeout(guard, MAX_EVENT_WAIT);
                    }
                    n = RCU_CALL_COUNT.load(Ordering::Relaxed);
                    // Saturate: the idle loop may spin for a very long time.
                    tries = tries.saturating_add(1);
                }

                // Claim `n` callbacks, wait out a grace period, then run exactly those.
                RCU_CALL_COUNT.fetch_sub(n, Ordering::SeqCst);
                synchronize_rcu();
                while n > 0 {
                    // The queue lock is released at the end of this statement, so the
                    // callback runs unlocked and may itself call `call_rcu`.
                    let node = RCU_GP_HEAD.lock().unwrap().pop_back();
                    match node {
                        Some(callback) => {
                            n -= 1;
                            callback();
                        }
                        None => {
                            // Nothing to pop yet: wait for a producer and retry. Every
                            // popped callback is run; none is discarded.
                            let guard = RCU_CALL_READY_EVENT_COND.lock().unwrap();
                            let _ = RCU_CALL_READY_EVENT.wait_timeout(guard, MAX_EVENT_WAIT);
                        }
                    }
                }
            }
        })
        .unwrap()
}

/// An rcu lock protecting a value of type `T`.
///
/// `inner` serializes writers; `data` holds the protected value. Taking a read
/// guard does not touch `inner`, so readers and a writer can access `data` at the
/// same time and the writer is responsible for keeping the data consistent.
#[derive(Default)]
pub struct RcuLock<T: ?Sized> {
    inner: Mutex<bool>,
    data: UnsafeCell<T>,
}

// The lock owns its `T`, so it can move to another thread whenever `T` can.
unsafe impl<T: ?Sized + Send> Send for RcuLock<T> {}
// Sharing the lock hands out both `&T` (readers) and `&mut T` (writer) across threads.
unsafe impl<T: ?Sized + Send + Sync> Sync for RcuLock<T> {}

/// Read guard of an [`RcuLock`].
///
/// Holds a raw pointer to the protected value and ends the thread's read-side
/// critical section when dropped.
///
/// NOTE(review): the guard carries no lifetime tying it to the `RcuLock` it was
/// created from, so the compiler does not stop it from outliving the lock, in
/// which case `data` dangles. Callers must keep the lock alive while a guard exists.
pub struct RcuLockReadGuard<T: ?Sized> {
    data: NonNull<T>,
}

// The read section is tracked in thread-local state, so the guard must be
// dropped on the thread that created it.
impl<T: ?Sized> !Send for RcuLockReadGuard<T> {}
unsafe impl<T: ?Sized + Sync> Sync for RcuLockReadGuard<T> {}

/// Write guard of an [`RcuLock`].
///
/// Keeps the lock's writer mutex held for as long as the guard lives; the
/// mutex is released when the guard is dropped.
pub struct RcuLockWriteGuard<'a, T: ?Sized + 'a> {
    // The lock whose data this guard gives access to.
    lock: &'a RcuLock<T>,
    // Guard of the writer mutex; never read, held only for its `Drop`.
    #[allow(unused)]
    data: MutexGuard<'a, bool>,
}

// Like `MutexGuard`, the guard must not move to another thread.
impl<T: ?Sized> !Send for RcuLockWriteGuard<'_, T> {}
unsafe impl<T: ?Sized + Sync> Sync for RcuLockWriteGuard<'_, T> {}

impl<T> RcuLock<T> {
    /// 创建一个 rcu lock
    #[inline]
    pub const fn new(t: T) -> RcuLock<T> {
        RcuLock { inner: Mutex::new(false), data: UnsafeCell::new(t) }
    }
}

impl<T: ?Sized> RcuLock<T> {
    /// Enters a read-side critical section and returns a guard for the data.
    ///
    /// Never blocks and never fails; the `LockResult` is always `Ok`. The
    /// critical section ends when the returned guard is dropped.
    #[inline]
    pub fn read(&self) -> LockResult<RcuLockReadGuard<T>> {
        rcu_read_lock();
        // SAFETY: `self` is a live lock, so its data pointer is valid here.
        unsafe { RcuLockReadGuard::new(self) }
    }

    /// Enters the write-side critical section, blocking until no other writer holds it.
    ///
    /// Readers are not excluded; the writer has to keep the data consistent.
    ///
    /// # Errors
    /// Returns a `PoisonError` wrapping the guard if a previous writer panicked
    /// while holding the lock. The guard can still be obtained with
    /// `PoisonError::into_inner`.
    #[inline]
    pub fn write(&self) -> LockResult<RcuLockWriteGuard<'_, T>> {
        match self.inner.lock() {
            Ok(guard) => Ok(RcuLockWriteGuard { lock: self, data: guard }),
            Err(poisoned) => Err(std::sync::PoisonError::new(RcuLockWriteGuard {
                lock: self,
                data: poisoned.into_inner(),
            })),
        }
    }

    /// Tries to enter the write-side critical section without blocking.
    ///
    /// # Errors
    /// Returns `TryLockError::WouldBlock` if another writer currently holds the
    /// lock, and `TryLockError::Poisoned` (wrapping the guard) if a previous
    /// writer panicked while holding it.
    #[inline]
    pub fn try_write(&self) -> TryLockResult<RcuLockWriteGuard<'_, T>> {
        match self.inner.try_lock() {
            Ok(guard) => Ok(RcuLockWriteGuard { lock: self, data: guard }),
            Err(TryLockError::Poisoned(poisoned)) => {
                Err(TryLockError::Poisoned(std::sync::PoisonError::new(RcuLockWriteGuard {
                    lock: self,
                    data: poisoned.into_inner(),
                })))
            }
            Err(TryLockError::WouldBlock) => Err(TryLockError::WouldBlock),
        }
    }

    /// Consumes the lock and returns the protected value.
    ///
    /// Always returns `Ok`; poisoning of the writer mutex is not checked.
    pub fn into_inner(self) -> LockResult<T>
    where
        T: Sized,
    {
        Ok(self.data.into_inner())
    }

    /// Returns a mutable reference to the protected value.
    ///
    /// No locking is needed because `&mut self` already guarantees exclusive
    /// access to the lock. Always returns `Ok`.
    pub fn get_mut(&mut self) -> LockResult<&mut T> {
        Ok(self.data.get_mut())
    }
}

/// Formats the calling thread's rcu reader state (`depth`, `ctr`, `waiting`);
/// the protected value itself is not printed.
impl<T: ?Sized + fmt::Debug> fmt::Debug for RcuLock<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("RcuLock");
        RCU_READER.with(|cell| {
            let reader = cell.borrow();
            let shared = &reader.data;
            out.field("depth", &reader.depth);
            out.field("ctr", &shared.ctr.load(Ordering::Relaxed));
            out.field("waiting", &shared.waiting.load(Ordering::Relaxed));
            out.finish_non_exhaustive()
        })
    }
}

impl<T> From<T> for RcuLock<T> {
    fn from(value: T) -> Self {
        RcuLock::new(value)
    }
}

impl<'rculock, T: ?Sized> RcuLockReadGuard<T> {
    // Builds a read guard pointing at `lock`'s data. Always returns `Ok`.
    //
    // Safety: the caller must have entered a read-side critical section and must
    // keep `lock` alive for as long as the guard is used.
    unsafe fn new(lock: &'rculock RcuLock<T>) -> LockResult<RcuLockReadGuard<T>> {
        let raw = lock.data.get();
        // `UnsafeCell::get` never returns null, so the unchecked constructor is fine.
        let data = NonNull::new_unchecked(raw);
        Ok(RcuLockReadGuard { data })
    }
}

impl<'rculock, T: ?Sized> RcuLockWriteGuard<'rculock, T> {
    // Builds a write guard from `lock` and the already acquired guard of its
    // writer mutex. Always returns `Ok`.
    //
    // Safety: `data` must be the guard of `lock`'s own writer mutex.
    unsafe fn new(
        lock: &'rculock RcuLock<T>,
        data: MutexGuard<'rculock, bool>,
    ) -> LockResult<RcuLockWriteGuard<'rculock, T>> {
        let guard = RcuLockWriteGuard { lock, data };
        Ok(guard)
    }
}

/// Formats the guarded value with its own `Debug` implementation.
///
/// Accepts unsized `T`, like the guard type itself and its `Display` impl.
impl<T: ?Sized + fmt::Debug> fmt::Debug for RcuLockReadGuard<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&**self, f)
    }
}

/// Formats the guarded value with its own `Display` implementation.
impl<T: ?Sized + fmt::Display> fmt::Display for RcuLockReadGuard<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        fmt::Display::fmt(value, f)
    }
}

/// Formats the guarded value with its own `Debug` implementation.
///
/// Accepts unsized `T`, like the guard type itself and its `Display` impl.
impl<T: ?Sized + fmt::Debug> fmt::Debug for RcuLockWriteGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&**self, f)
    }
}

/// Formats the guarded value with its own `Display` implementation.
impl<T: ?Sized + fmt::Display> fmt::Display for RcuLockWriteGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        fmt::Display::fmt(value, f)
    }
}

/// Gives shared access to the value the guard points at.
impl<T: ?Sized> Deref for RcuLockReadGuard<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let raw = self.data.as_ptr();
        // SAFETY: the pointer was taken from a live `RcuLock` when the guard was
        // created; it stays valid only while that lock is alive.
        unsafe { &*raw }
    }
}

/// Gives shared access to the value protected by the guard's lock.
impl<T: ?Sized> Deref for RcuLockWriteGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        let raw: *const T = self.lock.data.get();
        // SAFETY: the guard borrows the lock, so the data is alive, and it holds
        // the writer mutex, so no other write guard exists.
        unsafe { &*raw }
    }
}

/// Gives mutable access to the value protected by the guard's lock.
impl<T: ?Sized> DerefMut for RcuLockWriteGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        let raw: *mut T = self.lock.data.get();
        // SAFETY: the guard borrows the lock, so the data is alive, and it holds
        // the writer mutex, so no other write guard exists. Readers are not
        // excluded; keeping the data consistent for them is left to the writer.
        unsafe { &mut *raw }
    }
}

/// Ends the read-side critical section that was entered when the guard was created.
impl<T: ?Sized> Drop for RcuLockReadGuard<T> {
    fn drop(&mut self) {
        // Decrements the thread's nesting depth; the outermost unlock marks the
        // thread quiescent and may wake `synchronize_rcu`.
        rcu_read_unlock();
    }
}

/// Reports the rcu state of a thread as a string.
///
/// Looks the thread up by id and returns its read-lock nesting depth, its
/// grace-period counter and whether `synchronize_rcu` is waiting for it,
/// formatted as `"depth: {}, ctr: {}, waiting: {}"`.
///
/// Returns `None` if the thread has not been registered with
/// `rcu_register_thread`.
///
/// # Panics
/// Panics if one of the global tables' mutexes is poisoned.
pub fn rcu_info(id: &ThreadId) -> Option<String> {
    // Take the locks in the same order as `rcu_register_thread` and
    // `rcu_unregister_thread` (registry first, then the depth table); the
    // opposite order can deadlock against a thread that is (un)registering.
    let registry = RCU_REGISTRY_LOCK.lock().unwrap();
    let depths = DEPTH_COUNT.lock().unwrap();

    let data = registry.get(id)?;
    let depth_ptr = *depths.get(id)? as *const u32;
    // SAFETY: the address was published by `rcu_register_thread` and is removed by
    // `rcu_unregister_thread` under the lock held here, so it refers to the
    // thread-local depth of a still-registered thread.
    // NOTE(review): this relies on threads unregistering before they exit, and the
    // read is not synchronized with the owning thread's updates of `depth`.
    let depth = unsafe { depth_ptr.as_ref() }.copied()?;

    let ctr = data.ctr.load(Ordering::Relaxed);
    let waiting = data.waiting.load(Ordering::Relaxed);

    Some(format!("depth: {depth}, ctr: {ctr}, waiting: {waiting}"))
}
